15e258
@@ -24,6 +24,8 @@
 import java.util.Set;
 import java.util.Stack;
 
+import com.google.common.collect.Sets;
+import org.apache.hadoop.hive.ql.exec.LateralViewForwardOperator;
 import org.apache.hadoop.hive.ql.exec.OperatorUtils;
 import org.apache.hadoop.hive.ql.exec.TableScanOperator;
 import org.slf4j.Logger;
@@ -80,7 +82,7 @@
       return null;
     }
 
-    LOG.info("Check if it can be converted to map join");
+    LOG.info("Check if operator {} can be converted to map join", joinOp);
     long[] mapJoinInfo = getMapJoinConversionInfo(joinOp, context);
     int mapJoinConversionPos = (int) mapJoinInfo[0];
 
@@ -196,25 +198,40 @@
private int convertJoinBucketMapJoin(JoinOperator joinOp, MapJoinOperator mapJoi
     // max. This table is either the big table or we cannot convert.
     boolean bigTableFound = false;
     boolean useTsStats = context.getConf().getBoolean(HiveConf.ConfVars.SPARK_USE_TS_STATS_FOR_MAPJOIN.varname, false);
-    boolean hasUpstreamSinks = false;
 
-    // Check whether there's any upstream RS.
-    // If so, don't use TS stats because they could be inaccurate.
-    for (Operator<? extends OperatorDesc> parentOp : joinOp.getParentOperators()) {
-      Set<ReduceSinkOperator> parentSinks =
-          OperatorUtils.findOperatorsUpstream(parentOp, ReduceSinkOperator.class);
-      parentSinks.remove(parentOp);
-      if (!parentSinks.isEmpty()) {
-        hasUpstreamSinks = true;
+    // TS stats only reflect the scanned tables. If a branch has an upstream operator that can
+    // increase the output data size (e.g. JOIN, LATERAL VIEW), its real size is unknown, so that
+    // branch must become the big table branch; at most one such branch is allowed.
+    if (useTsStats) {
+      LOG.debug("Checking map join optimization for operator {} using TS stats", joinOp);
+      for (Operator<? extends OperatorDesc> parentOp : joinOp.getParentOperators()) {
+        if (isBigTableBranch(parentOp)) {
+          if (bigTablePosition < 0 && bigTableCandidateSet.contains(pos)) {
+            LOG.debug("Found a big table branch with parent operator {} and position {}", parentOp, pos);
+            bigTablePosition = pos;
+            bigTableFound = true;
+            bigInputStat = new Statistics();
+            bigInputStat.setDataSize(Long.MAX_VALUE);
+          } else {
+            // Either a second big table branch was found, or this branch is not a valid big
+            // table candidate. Its size is unknown either way, so MapJoin conversion is disabled.
+            LOG.debug("Cannot enable map join optimization for operator {}", joinOp);
+            return new long[]{-1, 0, 0};
+          }
+        }
+        pos++;
       }
     }
 
-    // If we are using TS stats and this JOIN has at least one upstream RS, disable MapJoin conversion.
-    if (useTsStats && hasUpstreamSinks) {
-      return new long[]{-1, 0, 0};
-    }
+    pos = 0; // restart the position counter for the main pass over the join's parents
 
     for (Operator<? extends OperatorDesc> parentOp : joinOp.getParentOperators()) {
+      // Skip the branch already chosen as the big table by the TS-stats check above, if any.
+      if (pos == bigTablePosition) {
+        pos++;
+        continue;
+      }
+
       Statistics currInputStat;
       if (useTsStats) {
         currInputStat = new Statistics();
@@ -255,9 +272,8 @@
private int convertJoinBucketMapJoin(JoinOperator joinOp, MapJoinOperator mapJoi
       }
 
       long inputSize = currInputStat.getDataSize();
-      if ((bigInputStat == null)
-          || ((bigInputStat != null)
-          && (inputSize > bigInputStat.getDataSize()))) {
+      // True when this input is larger than every input seen so far (or is the first one).
+      if (bigInputStat == null || inputSize > bigInputStat.getDataSize()) {
 
         if (bigTableFound) {
           // cannot convert to map join; we've already chosen a big table
@@ -317,6 +333,26 @@
private int convertJoinBucketMapJoin(JoinOperator joinOp, MapJoinOperator mapJoi
     return new long[]{bigTablePosition, connectedMapJoinSize, totalSize};
   }
 
+  /**
+   * Checks whether the branch ending at {@code op} may produce more data than the table
+   * scans feeding it report, i.e. whether {@code op} or any operator upstream of it is a
+   * JOIN or a LATERAL VIEW, both of which can increase the output data size. For such a
+   * branch the TS stats underestimate the real size, so the MapJoin optimization assumes
+   * the worst and treats it as the big table branch.
+   *
+   * <p>NOTE(review): a UDTF used without LATERAL VIEW can also multiply rows and is not
+   * checked here; confirm whether it should be.
+   *
+   * @param op the bottom operator of the branch, i.e. a parent of the join being optimized
+   * @return true if the branch contains a JoinOperator or LateralViewForwardOperator
+   */
+  private boolean isBigTableBranch(Operator<? extends OperatorDesc> op) {
+    // The upstream search may include op itself. The || short-circuits, so the second
+    // traversal only runs when no JOIN was found.
+    return !OperatorUtils.findOperatorsUpstream(op, JoinOperator.class).isEmpty()
+        || !OperatorUtils.findOperatorsUpstream(op, LateralViewForwardOperator.class).isEmpty();
+  }
+
   /**
    * Examines this operator and all the connected operators, for mapjoins that will be in the same work.
    * @param parentOp potential big-table parent operator, explore up from this.
